/*
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
 * This program is free software, you can redistribute it and/or modify it under the terms and conditions of
 * CANN Open Software License Agreement Version 2.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */

#include "ascir/Dialect/Asc/IR/Asc.h"
#include "ascir/Dialect/Asc/Transforms/Passes.h"
#include "ascir/Dialect/Asc/Utils/Utils.h"

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Dominance.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

namespace mlir {
namespace ascendc {
#define GEN_PASS_DEF_INSERTSYNC
#include "ascir/Dialect/Asc/Transforms/Passes.h.inc"
}  // namespace ascendc
}  // namespace mlir

using namespace mlir;

namespace {

Value findQueue(Value tensor) {
  if (auto op = tensor.getDefiningOp<ascendc::TQueBindAllocTensorOp>()) {
    return op.getQueue();
  }
  if (auto op = tensor.getDefiningOp<ascendc::TQueBindDequeTensorOp>()) {
    return op.getQueue();
  }
  return nullptr;
}

void enqueueTensors(func::FuncOp funcOp) {
  funcOp.walk([](ascendc::OpWithDst op) {
    auto tensor = op.getDst();
    if (!tensor || !isa<ascendc::BaseTensorType>(tensor.getType()) ||
        isa_and_present<ascendc::GlobalTensorOp>(tensor.getDefiningOp()))
      return;
    OpBuilder builder(op.getContext());
    builder.setInsertionPointAfter(op);
    if (auto queue = findQueue(tensor)) {
      builder.create<ascendc::TQueBindEnqueTensorOp>(op.getLoc(), queue, tensor);
      return;
    }
    builder.create<ascendc::PipeBarrierOp>(op.getLoc(), ascendc::Pipe::PIPE_V);
  });
}

void createSetGetValueSync(bool isBefore, OpBuilder &builder, Location loc) {
  ascendc::HardEvent currentEvent =
      isBefore ? ascendc::HardEvent::V_S : ascendc::HardEvent::S_V;
  Value pipe = builder.create<ascendc::PipeOp>(loc);
  auto eventId = builder
                     .create<ascendc::TPipeFetchEventIDOp>(loc, builder.getI8Type(), pipe,
                                                           currentEvent)
                     .getResult();
  builder.create<ascendc::SetFlagOp>(loc, currentEvent, eventId);
  builder.create<ascendc::WaitFlagOp>(loc, currentEvent, eventId);
}

void syncGetValueOp(func::FuncOp &funcOp) {
  funcOp.walk([](ascendc::LocalTensorGetValueOp op) {
    auto loc = op.getLoc();
    OpBuilder builder(op);
    createSetGetValueSync(true, builder, loc);
    builder.setInsertionPointAfter(op);
    createSetGetValueSync(false, builder, loc);
  });
}

void syncSetValueOp(func::FuncOp &funcOp) {
  funcOp.walk([](ascendc::LocalTensorSetValueOp op) {
    auto loc = op.getLoc();
    OpBuilder builder(op);
    if (auto forOp = op->getParentOfType<scf::ForOp>()) {
      constexpr unsigned oneOpOneYield = 2U;
      if (forOp.getBody()->getOperations().size() == oneOpOneYield) {
        builder.setInsertionPoint(forOp);
        createSetGetValueSync(true, builder, loc);
        builder.setInsertionPointAfter(forOp);
        createSetGetValueSync(false, builder, loc);
        return;
      }
    }
    createSetGetValueSync(true, builder, loc);
    builder.setInsertionPointAfter(op);
    createSetGetValueSync(false, builder, loc);
  });
}

bool reEnque(OpBuilder &b, Location loc, ascendc::TQueBindEnqueTensorOp enq,
             ascendc::TQueBindDequeTensorOp deq) {
  DominanceInfo di;
  if (!di.dominates(enq, deq)) {
    auto *enqParent = deq->getParentRegion()->findAncestorOpInRegion(*enq);
    if (!enqParent) {
      enq.emitOpError("failed to be hoisted to tensor deque op scope");
      return false;
    }
    b.setInsertionPointAfter(enqParent);
    b.create<ascendc::TQueBindEnqueTensorOp>(loc, enq.getQueue(), enq.getTensor());
    enq.erase();
  }
  return true;
}

bool dequeueTensors(Region &region) {
  DominanceInfo di;
  bool res = true;
  for (Block &block : region) {
    for (Operation &op : llvm::make_early_inc_range(block)) {
      auto enq = dyn_cast<ascendc::TQueBindEnqueTensorOp>(op);
      if (!enq) {
        for (Region &inner : op.getRegions()) {
          res &= dequeueTensors(inner);
        }
        continue;
      }
      auto tensor = enq.getTensor();
      std::vector<Operation *> users;
      llvm::copy_if(tensor.getUsers(), std::back_inserter(users),
                    [&](Operation *user) {
                      return !isa<ascendc::TQueBindFreeTensorOp>(user) &&
                              ascendc::opPrecedes(enq, user, di);
                    });
      if (users.empty()) {
        continue;
      }
      Operation *firstUser = *std::min_element(
        users.begin(), users.end(), [&](Operation *lhs, Operation *rhs) {
          return ascendc::opPrecedes(lhs, rhs, di);
        });
      auto *userInSameRegion =
          enq->getParentRegion()->findAncestorOpInRegion(*firstUser);
      if (userInSameRegion) {
        firstUser = userInSameRegion;
      }
      OpBuilder builder(firstUser);
      auto deq = builder.create<ascendc::TQueBindDequeTensorOp>(enq.getLoc(),
          tensor.getType(), enq.getQueue());
      if (!reEnque(builder, op.getLoc(), enq, deq)) {
        return false;
      }
      tensor.replaceUsesWithIf(deq.getTensor(), [&](OpOperand &opnd) {
        auto *owner = opnd.getOwner();
        return llvm::is_contained(users, owner);
      });
    }
  }
  return res;
}

void canonicalizeBarriers(func::FuncOp funcOp) {
  auto builder = OpBuilder::atBlockTerminator(&funcOp.getFunctionBody().back());
  builder.create<ascendc::PipeBarrierOp>(builder.getUnknownLoc(), ascendc::Pipe::PIPE_ALL);
  RewritePatternSet patterns(funcOp.getContext());
  ascendc::PipeBarrierOp::getCanonicalizationPatterns(patterns, funcOp.getContext());
  (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
}

struct InsertSyncPass : public ascendc::impl::InsertSyncBase<InsertSyncPass> {
public:
  void runOnOperation() override {
    func::FuncOp funcOp = getOperation();
    if (funcOp.isDeclaration()) {
      return;
    }
    enqueueTensors(funcOp);
    if (!dequeueTensors(funcOp.getRegion())) {
      signalPassFailure();
      return;
    }
    syncGetValueOp(funcOp);
    syncSetValueOp(funcOp);
    canonicalizeBarriers(funcOp);
  }
};

}  // namespace

namespace mlir {
namespace ascendc {
std::unique_ptr<Pass> createInsertSyncPass() {
  return std::make_unique<InsertSyncPass>();
}
}  // namespace ascendc
}  // namespace mlir
